import torch.nn as nn
import torch
import torch.nn.functional as F
from network.norm import DynamicNorm
from torch.autograd import Variable
import torch.distributions as D
from torch.distributions import kl_divergence


class QMixFuture_net(nn.Module):
    """QMIX-style monotonic mixer whose second layer is driven by a stochastic state latent.

    * First mixing layer: weights/bias come from hypernetworks on the global state.
    * Second mixing layer: weights/bias come from a sample of a Gaussian latent
      (``indi_*`` heads) computed from the global state.
    * Auxiliary branch: a GRU is run over (state, joint action) in reversed time
      order (see ``future_s_h``). A second Gaussian latent is inferred from its
      output. That latent yields a per-step scalar (``s_rew``) and a KL term
      against the state latent (``r_MI``).

    ``args`` must provide: state_shape, n_agents, n_actions, rnn_hidden_dim,
    qmix_hidden_dim, two_hyper_layers and (when two_hyper_layers is true)
    hyper_hidden_dim.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # NOTE: sub-modules are created in the original order so that seeded
        # parameter initialisation and state_dict keys stay unchanged.

        # Not used by any method of this class; presumably kept for checkpoint
        # compatibility or external use. Confirm before removing.
        self.norm = DynamicNorm(args.state_shape, only_for_last_dim=True, exclude_one_hot=True, exclude_nan=True)
        # Embeds (state, flattened joint action) before the future-state GRU.
        self.fc_state = nn.Sequential(
            nn.Linear(args.state_shape + args.n_agents * args.n_actions, args.rnn_hidden_dim),
            nn.ReLU(inplace=True),
        )

        # State -> Gaussian latent (mean, log-variance) driving the second mixing layer.
        self.indi_net = nn.Sequential(nn.Linear(args.state_shape, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.indi_mu = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.indi_lnsigma2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)

        # Future-state summary -> Gaussian latent (mean, log-variance) plus a scalar head.
        self.infer_net_s = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim), nn.ReLU(inplace=True))
        self.infer_mu_s = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.infer_lnsigma_s = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.pre_rew_s = nn.Linear(args.rnn_hidden_dim, 1)
        self.rnn_state = nn.GRU(
            input_size=args.rnn_hidden_dim,
            num_layers=1,
            hidden_size=args.rnn_hidden_dim,
            batch_first=True,
        )

        # Maps a sampled state latent to the input of the second-layer hypernetworks.
        self.params_dis_net = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim), nn.ReLU(inplace=True))

        # A network can only emit a flat vector. hyper_w1 therefore emits
        # n_agents * qmix_hidden_dim values, which forward() reshapes into an
        # (n_agents, qmix_hidden_dim) weight matrix per sample.
        if args.two_hyper_layers:
            self.hyper_w1 = nn.Sequential(nn.Linear(args.state_shape, args.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(args.hyper_hidden_dim, args.n_agents * args.qmix_hidden_dim))
            # Second-layer weights: qmix_hidden_dim values, reshaped to (qmix_hidden_dim, 1).
            self.hyper_w2 = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.qmix_hidden_dim))
        else:
            self.hyper_w1 = nn.Linear(args.state_shape, args.n_agents * args.qmix_hidden_dim)
            # Second-layer weights: qmix_hidden_dim values, reshaped to (qmix_hidden_dim, 1).
            self.hyper_w2 = nn.Linear(args.rnn_hidden_dim, args.qmix_hidden_dim * 1)

        # First-layer bias: one value per mixing hidden unit.
        self.hyper_b1 = nn.Linear(args.state_shape, args.qmix_hidden_dim)
        # Second-layer bias: one scalar per sample.
        self.hyper_b2 = nn.Sequential(nn.Linear(args.rnn_hidden_dim, args.qmix_hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(args.qmix_hidden_dim, 1)
                                      )

    def flipUp(self, s_in, mask):
        """Reverse, per sequence, the time order of the entries selected by ``mask``.

        Args:
            s_in: (batch, time, features) tensor; it is not modified.
            mask: (batch, time, 1) tensor; non-zero marks a valid step.

        Returns:
            Tensor shaped like ``s_in``. Its valid entries hold the valid entries
            of ``s_in`` in reversed time order, and every other entry is zero.
            When the valid steps of each sequence form a prefix, this reverses
            that prefix and leaves the padding at the end.
        """
        keep = mask.bool().expand_as(s_in)
        out = torch.zeros_like(s_in)
        # flip() already returns a new tensor, so no clone is needed. Boolean
        # indexing enumerates elements in row-major order, and flipping along
        # time preserves each sequence's count of valid entries. The k-th valid
        # slot of every sequence therefore receives the k-th valid entry of the
        # time-flipped input.
        out[keep] = s_in.flip(dims=[1])[keep.flip(dims=[1])]
        return out

    def future_s_h(self, s):
        """Encode every step with a GRU that runs from the last valid step backwards.

        Args:
            s: (episode_num, max_episode_len, state_shape + n_agents * n_actions)
               tensor. Rows that are entirely zero are treated as padding.

        Returns:
            (episode_num * max_episode_len, rnn_hidden_dim) tensor. Assuming the
            valid steps form a prefix, row t summarises steps t..last of its
            episode. Padding rows are zero.
        """
        # Initial hidden state on the same device and with the same dtype as the input.
        h0 = s.new_zeros((1, s.shape[0], self.args.rnn_hidden_dim))
        valid = torch.any(s.bool(), dim=-1, keepdim=True)
        s_h = self.flipUp(s, valid)          # reverse valid steps in time
        s_h = self.fc_state(s_h)
        s_h, _ = self.rnn_state(s_h, h0)
        s_h = self.flipUp(s_h, valid)        # restore original time order
        return s_h.reshape(-1, s_h.shape[-1])

    def reparametrize(self, mu, logvar, noise_scale=0.001):
        """Sample ``mu + noise_scale * eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        ``logvar`` is a log-variance, so ``exp(logvar / 2)`` is the standard
        deviation. The noise is drawn with ``randn_like``, so it has ``mu``'s
        shape, dtype and device. The output has the same shape as ``mu``.

        NOTE(review): the default ``noise_scale`` of 0.001 makes samples almost
        equal to ``mu``. Confirm this damping is intended.
        """
        eps = torch.randn_like(mu)
        return mu + noise_scale * eps * torch.exp(logvar / 2)

    def get_s_MI(self, s_h, latent_embed, episode_num):
        """Infer a latent from future-state summaries, predict a scalar, and score a KL term.

        Args:
            s_h: (N, rnn_hidden_dim) summaries as returned by ``future_s_h``.
            latent_embed: ``D.Normal`` over the state latent (built in ``forward``).
            episode_num: number of episodes, used to reshape the outputs.

        Returns:
            r_MI: (episode_num, N / episode_num) tensor holding
                KL(latent_embed || inferred Normal), summed over latent dims.
            s_rew: (episode_num, N / episode_num, 1) tensor predicted by
                ``pre_rew_s`` from a sample of the inferred latent (presumably a
                reward prediction; confirm against the training loss).
        """
        latent_infer = self.infer_net_s(s_h)
        infer_mu_s = self.infer_mu_s(latent_infer)
        # Treated as a log-variance (same convention as indi_lnsigma2).
        infer_lnsigma_s = self.infer_lnsigma_s(latent_infer)
        latent_infer_embed_s = D.Normal(infer_mu_s, torch.exp(infer_lnsigma_s / 2))
        rew_latent_dis = self.reparametrize(infer_mu_s, infer_lnsigma_s)
        s_rew = self.pre_rew_s(rew_latent_dis).reshape(episode_num, -1, 1)
        r_MI = kl_divergence(latent_embed, latent_infer_embed_s).sum(dim=-1).reshape(episode_num, -1)
        return r_MI, s_rew

    def forward(self, q_values, states, u):
        """Mix per-agent Q-values into Q_total and compute the auxiliary outputs.

        Args:
            q_values: (episode_num, max_episode_len, n_agents) per-agent Q-values.
            states: (episode_num, max_episode_len, state_shape) global states.
            u: actions. ``u.reshape(episode_num, max_episode_len, -1)`` must have
               n_agents * n_actions features (presumably one-hot; confirm).

        Returns:
            q_total: (episode_num, max_episode_len, 1) mixed Q-values.
            s_rew: (episode_num, max_episode_len, 1) output of ``get_s_MI``.
            r_MI_s: (episode_num, max_episode_len) KL term from ``get_s_MI``.
        """
        episode_num = q_values.size(0)
        # States concatenated with flattened joint actions feed the future encoder.
        s_f = torch.cat((states, u.reshape(u.shape[0], u.shape[1], -1)), dim=-1)
        q_values = q_values.view(-1, 1, self.args.n_agents)  # (N, 1, n_agents)
        states = states.reshape(-1, self.args.state_shape)   # (N, state_shape)

        # First mixing layer. abs() keeps the weights non-negative, which (with
        # the monotone ELU and the non-negative w2) keeps Q_total monotone in
        # each agent's Q-value.
        w1 = torch.abs(self.hyper_w1(states)).view(-1, self.args.n_agents, self.args.qmix_hidden_dim)
        b1 = self.hyper_b1(states).view(-1, 1, self.args.qmix_hidden_dim)
        hidden = F.elu(torch.bmm(q_values, w1) + b1)  # (N, 1, qmix_hidden_dim)

        # Second mixing layer, parameterised by a sample of the state latent.
        indi = self.indi_net(states)
        indi_mu = self.indi_mu(indi)
        indi_lnsigma2 = self.indi_lnsigma2(indi)  # log-variance
        indi_latent_dis = self.reparametrize(indi_mu, indi_lnsigma2)
        latent_embed = D.Normal(indi_mu, torch.exp(indi_lnsigma2 / 2))
        latent_para = self.params_dis_net(indi_latent_dis)

        w2 = torch.abs(self.hyper_w2(latent_para)).view(-1, self.args.qmix_hidden_dim, 1)
        b2 = self.hyper_b2(latent_para).view(-1, 1, 1)
        q_total = (torch.bmm(hidden, w2) + b2).view(episode_num, -1, 1)

        s_h = self.future_s_h(s_f)
        r_MI_s, s_rew = self.get_s_MI(s_h, latent_embed, episode_num)
        return q_total, s_rew, r_MI_s

